package etcdstore

import (
	"context"
	"errors"
	"fmt"
	"gitee.com/zackeus/go-boot/idgen/store"
	"gitee.com/zackeus/go-boot/tools/errorx"
	"gitee.com/zackeus/go-zero/core/logx"
	"gitee.com/zackeus/go-zero/core/threading"
	clientv3 "go.etcd.io/etcd/client/v3"
	"go.etcd.io/etcd/client/v3/concurrency"
	"google.golang.org/grpc/connectivity"
	"strconv"
	"sync"
	"sync/atomic"
	"time"
)

// defaultPfx is the default etcd key prefix under which workIds are registered.
const defaultPfx = "id/worker/"

// defaultLockKey is the default key of the etcd mutex held while a workId is chosen.
const defaultLockKey = "id/lock"

// defaultTTl is the default lease TTL (seconds, as handed to the etcd lease/session APIs).
const defaultTTl = 30

// defaultTimeout is the default initialisation timeout.
var defaultTimeout = 3 * time.Second

// Options customises an etcdStore before it is initialised; see New.
type Options func(s *etcdStore)

// etcdStore is a worker-id store backed by etcd. It reserves a workId by
// writing a key under prefix that is bound to a lease, and keeps that lease
// renewed in the background.
type etcdStore struct {
	once                sync.Once        // makes Shutdown run only once
	closed              atomic.Bool      // set when Shutdown starts
	ctx                 context.Context  // store context, cancelled through stopFunc
	stopFunc            func()           // cancels ctx
	workId              atomic.Uint64    // workId selected for this instance
	maxWorkerId         uint16           // workIds are chosen modulo this value
	client              *clientv3.Client // etcd client
	leaseID             clientv3.LeaseID // lease the workId key is bound to
	leaseAlive          atomic.Bool      // true while the keep-alive stream for leaseID runs
	reKeepAliveInterval time.Duration    // retry interval for re-registering after the lease is lost
	lockKey             string           // key of the etcd mutex used while choosing a workId
	prefix              string           // key prefix for workId keys and the global counter key
	ttl                 int64            // lease TTL in seconds
	timeout             time.Duration    // initialisation timeout (configured via WithTimeout)
}

// New creates an etcd-backed IdStore and reserves a workId for this process.
//
// maxWorkerId bounds the workId space: ids are picked modulo maxWorkerId, so
// it must be greater than zero. Options are applied in order on top of the
// defaults. On success the store's lease is registered and renewed in the
// background. On failure the store context is cancelled and, if the lease was
// already registered, it is revoked.
func New(client *clientv3.Client, maxWorkerId uint16, options ...Options) (store.IdStore, error) {
	if client == nil {
		return nil, errors.New("the etcd client is nil")
	}
	if maxWorkerId == 0 {
		// init reduces ids modulo maxWorkerId; zero would panic there.
		return nil, errors.New("the maxWorkerId must be greater than 0")
	}

	/* Build the store context; stop cancels it. */
	ctx, stop := context.WithCancel(context.Background())
	s := &etcdStore{
		ctx:                 ctx,
		stopFunc:            stop,
		maxWorkerId:         maxWorkerId,
		client:              client,
		reKeepAliveInterval: 3 * time.Second,
		lockKey:             defaultLockKey,
		prefix:              defaultPfx,
		ttl:                 defaultTTl,
		timeout:             defaultTimeout,
	}
	for _, option := range options {
		if option != nil {
			option(s)
		}
	}
	s.closed.Store(false)
	s.leaseAlive.Store(false)

	/* Choose a free workId. */
	if err := s.init(); err != nil {
		stop()
		return nil, err
	}
	/* Claim it with a lease. */
	if err := s.register(); err != nil {
		stop()
		return nil, err
	}
	/* Start lease renewal; revoke (which needs the live ctx) before cancelling. */
	if err := s.keepAlive(); err != nil {
		s.revoke()
		stop()
		return nil, err
	}

	return s, nil
}

// storeKey returns the etcd key for workId: the configured prefix followed by
// the decimal workId.
func (s *etcdStore) storeKey(workId uint16) string {
	suffix := strconv.FormatUint(uint64(workId), 10)
	return s.prefix + suffix
}

// init chooses a free workId for this instance and stores it in s.workId.
// The id is only claimed afterwards, by register.
//
// Selection runs under an etcd mutex on s.lockKey so that concurrent
// instances do not pick the same id:
//  1. put a timestamp on the global counter key "<prefix>global" and read it
//     back in the same transaction; its version is used as an increasing
//     counter and, modulo s.maxWorkerId, gives the starting candidate;
//  2. list every key under s.prefix to learn which workIds are taken;
//  3. probe forward from the candidate, wrapping around, until a free id is
//     found, or fail once all s.maxWorkerId ids have been checked.
//
// Waiting for the lock and the etcd reads/writes are bounded by s.timeout
// when it is positive.
func (s *etcdStore) init() error {
	if s.maxWorkerId == 0 {
		// Every id computation below is modulo maxWorkerId.
		return errors.New("the maxWorkerId must be greater than 0")
	}

	ctx := s.ctx
	if s.timeout > 0 {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(s.ctx, s.timeout)
		defer cancel()
	}

	/* Session backing the distributed mutex. */
	session, err := concurrency.NewSession(s.client, concurrency.WithTTL(int(s.ttl)))
	if err != nil {
		return err
	}
	defer func() {
		_ = session.Close()
	}()

	mu := concurrency.NewMutex(session, s.lockKey)
	if err := mu.Lock(ctx); err != nil {
		return err
	}
	defer func() {
		// Unlock with the store context, not the init deadline, so the lock is
		// still released if the deadline expired after it was acquired.
		_ = mu.Unlock(s.ctx)
	}()

	/* Bump the global counter key and read its version in one transaction. */
	globalKey := fmt.Sprintf("%sglobal", s.prefix)
	txResp, err := s.client.Txn(ctx).
		Then(clientv3.OpPut(globalKey, strconv.FormatInt(time.Now().UnixMilli(), 10)), clientv3.OpGet(globalKey)).
		Commit()
	if err != nil {
		return err
	}
	if !txResp.Succeeded || len(txResp.Responses) < 2 {
		return errors.New("the workId index tx failed")
	}
	kvs := txResp.Responses[1].GetResponseRange().GetKvs()
	if len(kvs) == 0 {
		return errors.New("the workId index tx returned no global key")
	}

	// Reduce in 64-bit space before narrowing, so the start is exactly
	// version mod maxWorkerId rather than (version mod 65536) mod maxWorkerId.
	currentWorkId := uint16(uint64(kvs[0].Version) % uint64(s.maxWorkerId))
	logx.Debugf("the init workId is [%v]", currentWorkId)

	/* Collect the workIds already registered under the prefix. */
	resp, err := s.client.Get(ctx, s.prefix, clientv3.WithPrefix())
	if err != nil {
		return err
	}
	taken := make(map[uint16]struct{}, len(resp.Kvs))
	for _, kv := range resp.Kvs {
		// The global counter key shares the prefix but holds a timestamp, not
		// a workId; counting it would mark an arbitrary id as taken.
		if string(kv.Key) == globalKey {
			continue
		}
		num, err := strconv.ParseUint(string(kv.Value), 10, 64)
		if err != nil {
			/* A value that is not an integer is an error. */
			return err
		}
		taken[uint16(num)] = struct{}{}
	}

	// Probe each of the maxWorkerId possible ids once. An int counter keeps the
	// loop finite even when maxWorkerId is math.MaxUint16.
	for i := 0; i < int(s.maxWorkerId); i++ {
		if _, ok := taken[currentWorkId]; !ok {
			logx.Debugf("the currentWorkId is [%v]", currentWorkId)
			s.workId.Store(uint64(currentWorkId))
			return nil
		}
		// currentWorkId < maxWorkerId <= 65535, so +1 cannot overflow uint16.
		currentWorkId = (currentWorkId + 1) % s.maxWorkerId
	}
	return fmt.Errorf("the workId exceeds maximum value [%v] ", s.maxWorkerId)
}

// register claims the workId chosen by init. It writes the id under
// storeKey(workId), bound to a freshly granted lease of s.ttl seconds.
//
// The write is guarded by a transaction that only succeeds when the key is
// not attached to any lease, so an id held by another live instance is
// never overwritten. On success the lease is recorded in s.leaseID. On
// failure the just-granted lease is revoked (best effort) so it does not
// linger until its TTL expires.
func (s *etcdStore) register() error {
	// Use the Lease API embedded in the client (as keepAlive/revoke do)
	// instead of clientv3.NewLease, which would allocate a new lessor on every
	// (re)registration and never close it.
	leaseGrant, err := s.client.Grant(s.ctx, s.ttl)
	if err != nil {
		return err
	}

	// Read the id once so the key and the stored value always agree.
	workId := s.workId.Load()
	key := s.storeKey(uint16(workId))
	/* Check the workId is free and bind it to the lease in one transaction. */
	txResp, err := s.client.Txn(s.ctx).
		If(clientv3.Compare(clientv3.LeaseValue(key), "=", 0)).
		Then(clientv3.OpPut(key, strconv.FormatUint(workId, 10), clientv3.WithLease(leaseGrant.ID))).
		Commit()
	if err != nil {
		// Best effort: the lease would expire after s.ttl anyway.
		_, _ = s.client.Revoke(s.ctx, leaseGrant.ID)
		return err
	}
	/* The transaction failed: the workId is held by another instance. */
	if !txResp.Succeeded {
		_, _ = s.client.Revoke(s.ctx, leaseGrant.ID)
		return fmt.Errorf("the workId [%v] reuse.", workId)
	}

	s.leaseID = leaseGrant.ID
	return nil
}

// keepAlive starts automatic renewal of s.leaseID and marks the lease alive.
// A background goroutine drains the renewal channel. When that channel
// closes, the goroutine marks the lease dead, revokes it and calls
// reKeepAlive to re-register. It also exits as soon as s.ctx is cancelled.
func (s *etcdStore) keepAlive() error {
	respCh, err := s.client.KeepAlive(s.ctx, s.leaseID)
	if err != nil {
		return err
	}
	s.leaseAlive.Store(true)

	threading.GoSafe(func() {
		for {
			select {
			case <-s.ctx.Done():
				return
			case _, open := <-respCh:
				if open {
					// Renewal acknowledgement; keep draining.
					continue
				}
				// Renewal stream ended (lease expired or connection lost).
				s.leaseAlive.Store(false)
				s.revoke()
				if reErr := s.reKeepAlive(); reErr != nil {
					logx.Errorf("idgen etcd store KeepAlive: %s", reErr.Error())
				}
				return
			}
		}
	})
	return nil
}

// reKeepAlive retries, every s.reKeepAliveInterval, re-registering the current
// workId and restarting lease renewal. It returns nil once a retry succeeds,
// or silently when the store is closed, s.ctx is cancelled, the client is gone
// or the client's active connection is nil or shut down.
func (s *etcdStore) reKeepAlive() error {
	if s.closed.Load() {
		return nil
	}

	retry := time.NewTicker(s.reKeepAliveInterval)
	defer retry.Stop()

	for {
		select {
		case <-s.ctx.Done():
			return nil
		case <-retry.C:
		}

		// Nothing left to retry against.
		if s.client == nil {
			return nil
		}
		if conn := s.client.ActiveConnection(); conn == nil || conn.GetState() == connectivity.Shutdown {
			return nil
		}

		if err := s.register(); err != nil {
			logx.Error(errorx.Wrap(err, "reKeepAlive failed, will try again."))
			continue
		}
		if err := s.keepAlive(); err != nil {
			logx.Error(errorx.Wrap(err, "reKeepAlive failed, will try again."))
			s.revoke()
			continue
		}

		logx.Debugf("the workId [%v] reKeepAlive success", s.workId.Load())
		return nil
	}
}

// revoke releases s.leaseID on the etcd cluster and logs any failure instead
// of returning it. It does nothing when the client is nil or the store has
// been marked closed.
func (s *etcdStore) revoke() {
	if s.client != nil && !s.closed.Load() {
		if _, err := s.client.Revoke(s.ctx, s.leaseID); err != nil {
			logx.Errorf("idgen etcd store revoke: %s", err.Error())
		}
	}
}

// Available reports whether the store can currently back an ID generator:
// it must not have been shut down and its lease must be alive. Each failure
// case is logged before returning false.
func (s *etcdStore) Available() bool {
	switch {
	case s.closed.Load():
		logx.Error("the id store has closed.")
		return false
	case !s.leaseAlive.Load():
		// The lease is not being renewed.
		logx.Alert(fmt.Sprintf("the workId [%v] lease not alive.", s.workId.Load()))
		return false
	default:
		return true
	}
}

// GetWorkerId returns the workId selected for this instance.
func (s *etcdStore) GetWorkerId() uint16 {
	id := s.workId.Load()
	return uint16(id)
}

// Shutdown stops the store. Only the first call has any effect.
//
// It marks the store closed first, so the background renewal goroutine and
// reKeepAlive stop reacting to the lease going away. It then revokes the lease
// so the workId key is released right away instead of after the TTL. When g
// is true it also closes the etcd client. Finally it drops the client
// reference and cancels the store context.
func (s *etcdStore) Shutdown(g bool) {
	s.once.Do(func() {
		/* Mark closed first. */
		s.closed.Store(true)
		s.leaseAlive.Store(false)

		// revoke() returns early once closed is set, so calling it here (as
		// before) never released the lease. Revoke explicitly instead, bounded
		// by s.timeout when positive so Shutdown cannot hang on an
		// unreachable cluster.
		if s.client != nil && s.leaseID != clientv3.NoLease {
			revokeCtx, cancel := s.ctx, context.CancelFunc(func() {})
			if s.timeout > 0 {
				revokeCtx, cancel = context.WithTimeout(s.ctx, s.timeout)
			}
			if _, err := s.client.Revoke(revokeCtx, s.leaseID); err != nil {
				logx.Errorf("idgen etcd store revoke: %s", err.Error())
			}
			cancel()
		}

		if s.client != nil && g {
			_ = s.client.Close()
		}

		s.client = nil
		s.stopFunc()
	})
}

// WithTimeout returns an option that sets the store's initialisation timeout.
func WithTimeout(t time.Duration) Options {
	apply := func(s *etcdStore) {
		s.timeout = t
	}
	return apply
}

// WithLockKey returns an option that sets the key of the etcd mutex held while
// a workId is chosen.
func WithLockKey(k string) Options {
	apply := func(s *etcdStore) {
		s.lockKey = k
	}
	return apply
}

// WithPrefix returns an option that sets the etcd key prefix under which
// workIds (and the global counter key) are stored.
func WithPrefix(p string) Options {
	apply := func(s *etcdStore) {
		s.prefix = p
	}
	return apply
}

// WithTTL returns an option that sets the lease TTL, in seconds. The TTL is
// used both for the workId lease and for the session backing the
// initialisation mutex.
//
// Non-positive values are ignored and the default is kept. The session's
// concurrency.WithTTL already ignores them, so accepting them here would
// leave the lease and the session with different TTLs.
func WithTTL(n int64) Options {
	return func(s *etcdStore) {
		if n > 0 {
			s.ttl = n
		}
	}
}

// WithKeepAliveInterval returns an option that sets how often the store
// retries re-registering its workId after the lease has been lost.
//
// Non-positive durations are ignored and the default is kept.
// reKeepAlive passes this value to time.NewTicker, which panics on them.
func WithKeepAliveInterval(t time.Duration) Options {
	return func(s *etcdStore) {
		if t > 0 {
			s.reKeepAliveInterval = t
		}
	}
}
